{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fluF3_oOgkWF"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "AJs7HHFmg1M9"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jYysdyb-CaWM"
      },
      "source": [
        "# Simple audio recognition: Recognizing keywords"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CNbqmZy0gbyE"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/tutorials/audio/simple_audio\">\n",
        "    <img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />\n",
        "    View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/audio/simple_audio.ipynb\">\n",
        "    <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />\n",
        "    Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/tutorials/audio/simple_audio.ipynb\">\n",
        "    <img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />\n",
        "    View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/tutorials/audio/simple_audio.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SPfDNFlb66XF"
      },
      "source": [
        "This tutorial shows you how to build a basic speech recognition network that recognizes eight different words. Real speech and audio recognition systems are much more complex, but, like MNIST for images, this tutorial should give you a basic understanding of the techniques involved. Once you've completed it, you'll have a model that tries to classify a one-second audio clip as one of \"down\", \"go\", \"left\", \"no\", \"right\", \"stop\", \"up\", or \"yes\"."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Go9C3uLL8Izc"
      },
      "source": [
        "## Setup\n",
        "\n",
        "Import necessary modules and dependencies."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dzLKpmZICaWN"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import pathlib\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import seaborn as sns\n",
        "import tensorflow as tf\n",
        "\n",
        "from tensorflow.keras.layers.experimental import preprocessing\n",
        "from tensorflow.keras import layers\n",
        "from tensorflow.keras import models\n",
        "from IPython import display\n",
        "\n",
        "\n",
        "# Set seed for experiment reproducibility\n",
        "seed = 42\n",
        "tf.random.set_seed(seed)\n",
        "np.random.seed(seed)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yR0EdgrLCaWR"
      },
      "source": [
        "## Import the Speech Commands dataset\n",
        "\n",
        "You'll write a script to download a portion of the [Speech Commands dataset](https://www.tensorflow.org/datasets/catalog/speech_commands). The original dataset consists of over 105,000 WAV audio files of people saying 35 different words. This data was collected by Google and released under a CC BY license.\n",
        "\n",
        "You'll be using a portion of the dataset to save time with data loading. Extract the `mini_speech_commands.zip` and load it in using the `tf.data` API."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2-rayb7-3Y0I"
      },
      "outputs": [],
      "source": [
        "data_dir = pathlib.Path('data/mini_speech_commands')\n",
        "if not data_dir.exists():\n",
        "  tf.keras.utils.get_file(\n",
        "      'mini_speech_commands.zip',\n",
        "      origin=\"http://storage.googleapis.com/download.tensorflow.org/data/mini_speech_commands.zip\",\n",
        "      extract=True,\n",
        "      cache_dir='.', cache_subdir='data')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BgvFq3uYiS5G"
      },
      "source": [
        "Check basic statistics about the dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "70IBxSKxA1N9"
      },
      "outputs": [],
      "source": [
        "commands = np.array(tf.io.gfile.listdir(str(data_dir)))\n",
        "commands = commands[commands != 'README.md']\n",
        "print('Commands:', commands)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aMvdU9SY8WXN"
      },
      "source": [
        "Extract the audio files into a list and shuffle it."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hlX685l1wD9k"
      },
      "outputs": [],
      "source": [
        "filenames = tf.io.gfile.glob(str(data_dir) + '/*/*')\n",
        "filenames = tf.random.shuffle(filenames)\n",
        "num_samples = len(filenames)\n",
        "print('Number of total examples:', num_samples)\n",
        "print('Number of examples per label:',\n",
        "      len(tf.io.gfile.listdir(str(data_dir/commands[0]))))\n",
        "print('Example file tensor:', filenames[0])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9vK3ymy23MCP"
      },
      "source": [
        "Split the files into training, validation and test sets using an 80:10:10 ratio, respectively."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Cv_wts-l3KgD"
      },
      "outputs": [],
      "source": [
        "train_files = filenames[:6400]\n",
        "val_files = filenames[6400: 6400 + 800]\n",
        "test_files = filenames[-800:]\n",
        "\n",
        "print('Training set size', len(train_files))\n",
        "print('Validation set size', len(val_files))\n",
        "print('Test set size', len(test_files))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "g2Cj9FyvfweD"
      },
      "source": [
        "## Reading audio files and their labels"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "j1zjcWteOcBy"
      },
      "source": [
        "The audio file will initially be read as a binary file, which you'll want to convert into a numerical tensor.\n",
        "\n",
        "To load an audio file, you will use [`tf.audio.decode_wav`](https://www.tensorflow.org/api_docs/python/tf/audio/decode_wav), which returns the WAV-encoded audio as a Tensor and the sample rate.\n",
        "\n",
        "A WAV file contains time series data with a set number of samples per second. \n",
        "Each sample represents the amplitude of the audio signal at that specific time. In a 16-bit system, like the files in `mini_speech_commands`, the values range from -32768 to 32767. \n",
        "The sample rate for this dataset is 16kHz.\n",
        "Note that `tf.audio.decode_wav` will normalize the values to the range [-1.0, 1.0]."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9PjJ2iXYwftD"
      },
      "outputs": [],
      "source": [
        "def decode_audio(audio_binary):\n",
        "  audio, _ = tf.audio.decode_wav(audio_binary)\n",
        "  return tf.squeeze(audio, axis=-1)"
      ]
    },
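    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the normalization concrete, the following sketch encodes a synthetic tone as WAV bytes with `tf.audio.encode_wav` and decodes it again with `tf.audio.decode_wav`. The 440 Hz tone and the `demo_` variable names are assumptions made for illustration; they are not part of the dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "\n",
        "# Hypothetical one-second, 440 Hz test tone sampled at 16 kHz.\n",
        "demo_rate = 16000\n",
        "demo_t = np.arange(demo_rate) / demo_rate\n",
        "demo_tone = (0.5 * np.sin(2 * np.pi * 440 * demo_t)).astype(np.float32)\n",
        "\n",
        "# encode_wav expects a float32 tensor of shape [samples, channels].\n",
        "demo_wav_bytes = tf.audio.encode_wav(demo_tone[:, np.newaxis], demo_rate)\n",
        "\n",
        "# decode_wav returns the audio scaled to [-1.0, 1.0] plus the sample rate.\n",
        "demo_audio, demo_decoded_rate = tf.audio.decode_wav(demo_wav_bytes)\n",
        "demo_audio = tf.squeeze(demo_audio, axis=-1)\n",
        "\n",
        "print('Decoded shape:', demo_audio.shape)\n",
        "print('Sample rate:', demo_decoded_rate.numpy())\n",
        "print('Value range:', demo_audio.numpy().min(), demo_audio.numpy().max())"
      ]
    },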
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GPQseZElOjVN"
      },
      "source": [
        "The label for each WAV file is its parent directory."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8VTtX1nr3YT-"
      },
      "outputs": [],
      "source": [
        "def get_label(file_path):\n",
        "  parts = tf.strings.split(file_path, os.path.sep)\n",
        "\n",
        "  # Note: You'll use indexing here instead of tuple unpacking to enable this \n",
        "  # to work in a TensorFlow graph.\n",
        "  return parts[-2] "
      ]
    },
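    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick check of the indexing, the sketch below splits a made-up file path and picks out the parent directory. The path is hypothetical and uses `/` as the separator rather than `os.path.sep`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "# Hypothetical path; real paths come from the `filenames` tensor.\n",
        "demo_path = tf.constant('data/mini_speech_commands/yes/example.wav')\n",
        "demo_parts = tf.strings.split(demo_path, '/')\n",
        "\n",
        "# The second-to-last component is the parent directory, i.e. the label.\n",
        "print(demo_parts[-2].numpy().decode('utf-8'))"
      ]
    },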
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "E8Y9w_5MOsr-"
      },
      "source": [
        "Define a function that takes the filename of a WAV file and returns a tuple containing the audio and its label for supervised training."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WdgUD5T93NyT"
      },
      "outputs": [],
      "source": [
        "def get_waveform_and_label(file_path):\n",
        "  label = get_label(file_path)\n",
        "  audio_binary = tf.io.read_file(file_path)\n",
        "  waveform = decode_audio(audio_binary)\n",
        "  return waveform, label"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nvN8W_dDjYjc"
      },
      "source": [
        "You will now apply `get_waveform_and_label` to the training files to extract the audio-label pairs, then check the results. You'll build the validation and test sets using a similar procedure later on."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0SQl8yXl3kNP"
      },
      "outputs": [],
      "source": [
        "AUTOTUNE = tf.data.AUTOTUNE\n",
        "files_ds = tf.data.Dataset.from_tensor_slices(train_files)\n",
        "waveform_ds = files_ds.map(get_waveform_and_label, num_parallel_calls=AUTOTUNE)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "voxGEwvuh2L7"
      },
      "source": [
        "Let's examine a few audio waveforms with their corresponding labels."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8yuX6Nqzf6wT"
      },
      "outputs": [],
      "source": [
        "rows = 3\n",
        "cols = 3\n",
        "n = rows*cols\n",
        "fig, axes = plt.subplots(rows, cols, figsize=(10, 12))\n",
        "for i, (audio, label) in enumerate(waveform_ds.take(n)):\n",
        "  r = i // cols\n",
        "  c = i % cols\n",
        "  ax = axes[r][c]\n",
        "  ax.plot(audio.numpy())\n",
        "  ax.set_yticks(np.arange(-1.2, 1.2, 0.2))\n",
        "  label = label.numpy().decode('utf-8')\n",
        "  ax.set_title(label)\n",
        "\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EWXPphxm0B4m"
      },
      "source": [
        "## Spectrogram\n",
        "\n",
        "You'll convert the waveform into a spectrogram, which shows frequency changes over time and can be represented as a 2D image. This can be done by applying the short-time Fourier transform (STFT) to convert the audio into the time-frequency domain.\n",
        "\n",
        "A Fourier transform ([`tf.signal.fft`](https://www.tensorflow.org/api_docs/python/tf/signal/fft)) converts a signal to its component frequencies, but loses all time information. The STFT ([`tf.signal.stft`](https://www.tensorflow.org/api_docs/python/tf/signal/stft)) splits the signal into windows of time and runs a Fourier transform on each window, preserving some time information, and returning a 2D tensor that you can run standard convolutions on.\n",
        "\n",
        "STFT produces an array of complex numbers representing magnitude and phase. However, you'll only need the magnitude for this tutorial, which can be derived by applying `tf.abs` on the output of `tf.signal.stft`. \n",
        "\n",
        "Choose the `frame_length` and `frame_step` parameters such that the generated spectrogram \"image\" is almost square. For more information on choosing STFT parameters, you can refer to [this video](https://www.coursera.org/lecture/audio-signal-processing/stft-2-tjEQe) on audio signal processing.\n",
        "\n",
        "You also want the waveforms to have the same length, so that when you convert them to spectrogram images, the results have the same dimensions. This can be done by zero-padding the audio clips that are shorter than one second.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_4CK75DHz_OR"
      },
      "outputs": [],
      "source": [
        "def get_spectrogram(waveform):\n",
        "  # Padding for files with less than 16000 samples\n",
        "  zero_padding = tf.zeros([16000] - tf.shape(waveform), dtype=tf.float32)\n",
        "\n",
        "  # Concatenate audio with padding so that all audio clips will be of the \n",
        "  # same length\n",
        "  waveform = tf.cast(waveform, tf.float32)\n",
        "  equal_length = tf.concat([waveform, zero_padding], 0)\n",
        "  spectrogram = tf.signal.stft(\n",
        "      equal_length, frame_length=255, frame_step=128)\n",
        "      \n",
        "  spectrogram = tf.abs(spectrogram)\n",
        "\n",
        "  return spectrogram"
      ]
    },
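    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The spectrogram shape follows from the STFT parameters. The sketch below works through that arithmetic for a one-second clip of 16,000 samples and compares it with what `tf.signal.stft` returns. It assumes the default `fft_length` (the smallest power of two that is at least `frame_length`) and the default `pad_end=False`; the `demo_` names are illustrative only."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "demo_samples, demo_frame_length, demo_frame_step = 16000, 255, 128\n",
        "\n",
        "# Number of full frames that fit when the end of the signal is not padded.\n",
        "demo_frames = 1 + (demo_samples - demo_frame_length) // demo_frame_step\n",
        "\n",
        "# Default FFT size: the smallest power of two >= frame_length.\n",
        "demo_fft_length = 1\n",
        "while demo_fft_length < demo_frame_length:\n",
        "  demo_fft_length *= 2\n",
        "demo_bins = demo_fft_length // 2 + 1\n",
        "\n",
        "demo_stft = tf.signal.stft(\n",
        "    tf.zeros([demo_samples]),\n",
        "    frame_length=demo_frame_length, frame_step=demo_frame_step)\n",
        "\n",
        "print('Expected shape:', (demo_frames, demo_bins))\n",
        "print('Actual shape:', tuple(demo_stft.shape))"
      ]
    },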
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5rdPiPYJphs2"
      },
      "source": [
        "Next, you will explore the data. Compare the waveform, the spectrogram and the actual audio of one example from the dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4Mu6Y7Yz3C-V"
      },
      "outputs": [],
      "source": [
        "for waveform, label in waveform_ds.take(1):\n",
        "  label = label.numpy().decode('utf-8')\n",
        "  spectrogram = get_spectrogram(waveform)\n",
        "\n",
        "print('Label:', label)\n",
        "print('Waveform shape:', waveform.shape)\n",
        "print('Spectrogram shape:', spectrogram.shape)\n",
        "print('Audio playback')\n",
        "display.display(display.Audio(waveform, rate=16000))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "e62jzb36-Jog"
      },
      "outputs": [],
      "source": [
        "def plot_spectrogram(spectrogram, ax):\n",
        "  # Convert the frequencies to log scale and transpose so that the time is\n",
        "  # represented in the x-axis (columns). An epsilon is added to avoid log of zero.\n",
        "  log_spec = np.log(spectrogram.T + np.finfo(float).eps)\n",
        "  height = log_spec.shape[0]\n",
        "  width = log_spec.shape[1]\n",
        "  X = np.linspace(0, np.size(spectrogram), num=width, dtype=int)\n",
        "  Y = range(height)\n",
        "  ax.pcolormesh(X, Y, log_spec)\n",
        "\n",
        "\n",
        "fig, axes = plt.subplots(2, figsize=(12, 8))\n",
        "timescale = np.arange(waveform.shape[0])\n",
        "axes[0].plot(timescale, waveform.numpy())\n",
        "axes[0].set_title('Waveform')\n",
        "axes[0].set_xlim([0, 16000])\n",
        "plot_spectrogram(spectrogram.numpy(), axes[1])\n",
        "axes[1].set_title('Spectrogram')\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GyYXjW07jCHA"
      },
      "source": [
        "Now transform the waveform dataset to have spectrogram images and their corresponding labels as integer IDs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "43IS2IouEV40"
      },
      "outputs": [],
      "source": [
        "def get_spectrogram_and_label_id(audio, label):\n",
        "  spectrogram = get_spectrogram(audio)\n",
        "  spectrogram = tf.expand_dims(spectrogram, -1)\n",
        "  label_id = tf.argmax(label == commands)\n",
        "  return spectrogram, label_id"
      ]
    },
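    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The comparison `label == commands` broadcasts the label against every command name and yields a boolean vector with a single `True`; `tf.argmax` then returns its position. A minimal sketch, using a made-up command list rather than the one read from disk:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "\n",
        "# Hypothetical command list; the tutorial reads the real one from `data_dir`.\n",
        "demo_commands = np.array(['down', 'go', 'left', 'no'])\n",
        "demo_label = tf.constant('left')\n",
        "\n",
        "# Element-wise comparison: True only where the label matches a command.\n",
        "demo_mask = demo_label == demo_commands\n",
        "\n",
        "# Cast to integers so argmax operates on 0/1 values.\n",
        "demo_label_id = tf.argmax(tf.cast(demo_mask, tf.int32))\n",
        "print(demo_mask.numpy(), demo_label_id.numpy())"
      ]
    },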
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yEVb_oK0oBLQ"
      },
      "outputs": [],
      "source": [
        "spectrogram_ds = waveform_ds.map(\n",
        "    get_spectrogram_and_label_id, num_parallel_calls=AUTOTUNE)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6gQpAAgMnyDi"
      },
      "source": [
        "Examine the spectrogram \"images\" for different samples of the dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QUbHfTuon4iF"
      },
      "outputs": [],
      "source": [
        "rows = 3\n",
        "cols = 3\n",
        "n = rows*cols\n",
        "fig, axes = plt.subplots(rows, cols, figsize=(10, 10))\n",
        "for i, (spectrogram, label_id) in enumerate(spectrogram_ds.take(n)):\n",
        "  r = i // cols\n",
        "  c = i % cols\n",
        "  ax = axes[r][c]\n",
        "  plot_spectrogram(np.squeeze(spectrogram.numpy()), ax)\n",
        "  ax.set_title(commands[label_id.numpy()])\n",
        "  ax.axis('off')\n",
        "  \n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "z5KdY8IF8rkt"
      },
      "source": [
        "## Build and train the model\n",
        "\n",
        "Now you can build and train your model. But before you do that, you'll need to repeat the training set preprocessing on the validation and test sets."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "10UI32QH_45b"
      },
      "outputs": [],
      "source": [
        "def preprocess_dataset(files):\n",
        "  files_ds = tf.data.Dataset.from_tensor_slices(files)\n",
        "  output_ds = files_ds.map(get_waveform_and_label, num_parallel_calls=AUTOTUNE)\n",
        "  output_ds = output_ds.map(\n",
        "      get_spectrogram_and_label_id,  num_parallel_calls=AUTOTUNE)\n",
        "  return output_ds"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HNv4xwYkB2P6"
      },
      "outputs": [],
      "source": [
        "train_ds = spectrogram_ds\n",
        "val_ds = preprocess_dataset(val_files)\n",
        "test_ds = preprocess_dataset(test_files)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "assnWo6SB3lR"
      },
      "source": [
        "Batch the training and validation sets for model training."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UgY9WYzn61EX"
      },
      "outputs": [],
      "source": [
        "batch_size = 64\n",
        "train_ds = train_ds.batch(batch_size)\n",
        "val_ds = val_ds.batch(batch_size)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GS1uIh6F_TN9"
      },
      "source": [
        "Add dataset [`cache()`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#cache) and [`prefetch()`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#prefetch) operations to reduce read latency while training the model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fdZ6M-F5_QzY"
      },
      "outputs": [],
      "source": [
        "train_ds = train_ds.cache().prefetch(AUTOTUNE)\n",
        "val_ds = val_ds.cache().prefetch(AUTOTUNE)"
      ]
    },
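    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A toy pipeline shows the same ordering on a small scale. This sketch assumes nothing about the audio data: it batches ten integers, caches the batches, and prefetches them. The dataset and the `demo_` names are illustrative only."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "demo_ds = tf.data.Dataset.range(10).batch(4)\n",
        "\n",
        "# cache() keeps the batches in memory after the first full pass;\n",
        "# prefetch() prepares upcoming batches while the current one is consumed.\n",
        "demo_ds = demo_ds.cache().prefetch(tf.data.AUTOTUNE)\n",
        "\n",
        "for demo_batch in demo_ds:\n",
        "  print(demo_batch.numpy())"
      ]
    },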
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rwHkKCQQb5oW"
      },
      "source": [
        "For the model, you'll use a simple convolutional neural network (CNN), since you have transformed the audio files into spectrogram images.\n",
        "The model also has the following additional preprocessing layers:\n",
        "- A [`Resizing`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/experimental/preprocessing/Resizing) layer to downsample the input to enable the model to train faster.\n",
        "- A [`Normalization`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/experimental/preprocessing/Normalization) layer to normalize each pixel in the image based on its mean and standard deviation.\n",
        "\n",
        "For the `Normalization` layer, its `adapt` method must first be called on the training data in order to compute aggregate statistics (i.e. mean and standard deviation)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ALYz7PFCHblP"
      },
      "outputs": [],
      "source": [
        "for spectrogram, _ in spectrogram_ds.take(1):\n",
        "  input_shape = spectrogram.shape\n",
        "print('Input shape:', input_shape)\n",
        "num_labels = len(commands)\n",
        "\n",
        "norm_layer = preprocessing.Normalization()\n",
        "norm_layer.adapt(spectrogram_ds.map(lambda x, _: x))\n",
        "\n",
        "model = models.Sequential([\n",
        "    layers.Input(shape=input_shape),\n",
        "    preprocessing.Resizing(32, 32), \n",
        "    norm_layer,\n",
        "    layers.Conv2D(32, 3, activation='relu'),\n",
        "    layers.Conv2D(64, 3, activation='relu'),\n",
        "    layers.MaxPooling2D(),\n",
        "    layers.Dropout(0.25),\n",
        "    layers.Flatten(),\n",
        "    layers.Dense(128, activation='relu'),\n",
        "    layers.Dropout(0.5),\n",
        "    layers.Dense(num_labels),\n",
        "])\n",
        "\n",
        "model.summary()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wFjj7-EmsTD-"
      },
      "outputs": [],
      "source": [
        "model.compile(\n",
        "    optimizer=tf.keras.optimizers.Adam(),\n",
        "    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "    metrics=['accuracy'],\n",
        ")"
      ]
    },
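    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Because the final `Dense` layer has no activation, the model outputs logits, and `from_logits=True` tells the loss to apply the softmax itself. The sketch below uses made-up logits for a three-class problem and checks that the loss matches the negative log of the softmax probability of the true class. All `demo_` values are assumptions for illustration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "\n",
        "# Hypothetical logits for one example with three classes; true class is 0.\n",
        "demo_logits = tf.constant([[2.0, 0.5, -1.0]])\n",
        "demo_true = tf.constant([0])\n",
        "\n",
        "demo_loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
        "demo_loss = demo_loss_fn(demo_true, demo_logits).numpy()\n",
        "\n",
        "# Manual computation: -log(softmax(logits)[true_class]).\n",
        "demo_probs = tf.nn.softmax(demo_logits).numpy()\n",
        "demo_manual = -np.log(demo_probs[0, 0])\n",
        "\n",
        "print('Keras loss:', demo_loss)\n",
        "print('Manual loss:', demo_manual)"
      ]
    },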
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ttioPJVMcGtq"
      },
      "outputs": [],
      "source": [
        "EPOCHS = 10\n",
        "history = model.fit(\n",
        "    train_ds, \n",
        "    validation_data=val_ds,  \n",
        "    epochs=EPOCHS,\n",
        "    callbacks=tf.keras.callbacks.EarlyStopping(verbose=1, patience=2),\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gjpCDeQ4mUfS"
      },
      "source": [
        "Let's check the training and validation loss curves to see how your model has improved during training."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nzhipg3Gu2AY"
      },
      "outputs": [],
      "source": [
        "metrics = history.history\n",
        "plt.plot(history.epoch, metrics['loss'], history.epoch, metrics['val_loss'])\n",
        "plt.legend(['loss', 'val_loss'])\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5ZTt3kO3mfm4"
      },
      "source": [
        "## Evaluate test set performance\n",
        "\n",
        "Let's run the model on the test set and check performance."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "biU2MwzyAo8o"
      },
      "outputs": [],
      "source": [
        "test_audio = []\n",
        "test_labels = []\n",
        "\n",
        "for audio, label in test_ds:\n",
        "  test_audio.append(audio.numpy())\n",
        "  test_labels.append(label.numpy())\n",
        "\n",
        "test_audio = np.array(test_audio)\n",
        "test_labels = np.array(test_labels)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ktUanr9mRZky"
      },
      "outputs": [],
      "source": [
        "y_pred = np.argmax(model.predict(test_audio), axis=1)\n",
        "y_true = test_labels\n",
        "\n",
        "test_acc = sum(y_pred == y_true) / len(y_true)\n",
        "print(f'Test set accuracy: {test_acc:.0%}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "en9Znt1NOabH"
      },
      "source": [
        "### Display a confusion matrix\n",
        "\n",
        "A confusion matrix is helpful to see how well the model did on each of the commands in the test set."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LvoSAOiXU3lL"
      },
      "outputs": [],
      "source": [
        "confusion_mtx = tf.math.confusion_matrix(y_true, y_pred) \n",
        "plt.figure(figsize=(10, 8))\n",
        "sns.heatmap(confusion_mtx, xticklabels=commands, yticklabels=commands, \n",
        "            annot=True, fmt='g')\n",
        "plt.xlabel('Prediction')\n",
        "plt.ylabel('Label')\n",
        "plt.show()"
      ]
    },
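    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To read the heatmap, it helps to know that `tf.math.confusion_matrix` puts true labels on the rows and predictions on the columns. A small sketch with made-up labels for four examples and three classes:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "# Hypothetical labels for four examples and three classes.\n",
        "demo_y_true = [0, 1, 2, 2]\n",
        "demo_y_pred = [0, 2, 2, 2]\n",
        "\n",
        "demo_cm = tf.math.confusion_matrix(demo_y_true, demo_y_pred, num_classes=3)\n",
        "\n",
        "# Row i, column j counts examples of class i predicted as class j.\n",
        "print(demo_cm.numpy())"
      ]
    },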
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mQGi_mzPcLvl"
      },
      "source": [
        "## Run inference on an audio file\n",
        "\n",
        "Finally, verify the model's prediction output using an input audio file of someone saying \"no.\" How well does your model perform?"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zRxauKMdhofU"
      },
      "outputs": [],
      "source": [
        "sample_file = data_dir/'no/01bb6a2a_nohash_0.wav'\n",
        "\n",
        "sample_ds = preprocess_dataset([str(sample_file)])\n",
        "\n",
        "for spectrogram, label in sample_ds.batch(1):\n",
        "  prediction = model(spectrogram)\n",
        "  plt.bar(commands, tf.nn.softmax(prediction[0]))\n",
        "  plt.title(f'Predictions for \"{commands[label[0]]}\"')\n",
        "  plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VgWICqdqQNaQ"
      },
      "source": [
        "If training went well, the model should clearly recognize the audio command as \"no.\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J3jF933m9z1J"
      },
      "source": [
        "## Next steps\n",
        "\n",
        "This tutorial showed how to do simple audio classification using a convolutional neural network with TensorFlow and Python.\n",
        "\n",
        "* To learn how to use transfer learning for audio classification, check out the [Sound classification with YAMNet](https://www.tensorflow.org/hub/tutorials/yamnet) tutorial.\n",
        "\n",
        "* To build your own interactive web app for audio classification, consider taking the [TensorFlow.js - Audio recognition using transfer learning codelab](https://codelabs.developers.google.com/codelabs/tensorflowjs-audio-codelab/index.html#0).\n",
        "\n",
        "* TensorFlow also has additional support for [audio data preparation and augmentation](https://www.tensorflow.org/io/tutorials/audio) to help with your own audio-based projects.\n"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "simple_audio.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
